import enum
import base64
import requests
import tempfile
import os
from clogger import logger
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.filters import SearchFilter, OrderingFilter
from rest_framework import mixins
from django.http.response import Http404
from django.conf import settings
from apps.task import seriaizer
from apps.task.models import JobModel
from apps.task.filter import TaskFilter, IsOwnerFilterBackend
from lib.base_view import CommonModelViewSet
from lib.response import success, not_found, ErrorResponse, other_response
from lib.authentications import TokenAuthentication
from service_scripts.base import FileItem
from asgiref.sync import async_to_sync
from .helper import DiagnosisHelper
from service_scripts.base import (
    DiagnosisJobResult,
    DiagnosisTaskResult,
)


class ContextType(enum.Enum):
    """Encoding of the ``results`` form values uploaded to ``sbs_task_result``.

    Members:
        TEXT: results are sent as plain text and used as-is.
        BASE64: results are base64-encoded and decoded to text before use.
    """

    TEXT = "text"  # plain-text payload
    BASE64 = "base64"  # base64-encoded payload


class TaskAPIView(
    CommonModelViewSet,
    mixins.ListModelMixin,
    mixins.RetrieveModelMixin,
    mixins.DestroyModelMixin,
    mixins.CreateModelMixin,
):
    """Diagnosis task API.

    Provides list / retrieve / destroy / create for ``JobModel`` tasks, plus the
    step-by-step (sbs) create/result actions, task hooks, offline log import,
    a health check and a proxy to the host list API.
    """

    queryset = JobModel.objects.all().order_by("-created_at")
    serializer_class = seriaizer.JobListSerializer
    filter_backends = (
        IsOwnerFilterBackend,
        DjangoFilterBackend,
        SearchFilter,
        OrderingFilter,
    )
    search_fields = ("id", "task_id", "created_by__id", "status", "params")  # fuzzy search
    filterset_class = TaskFilter  # exact-match filtering
    authentication_classes = [TokenAuthentication]
    create_requird_fields = ["service_name"]
    lookup_field = "task_id"
    # Seconds to wait for the upstream host API in get_host() before giving up.
    host_request_timeout = 10

    def get_authenticators(self):
        """Return the authenticators to apply to the current request.

        Requests whose path ends with ``health_check/`` and single-task lookups
        (GET with ``task_id`` in the URL kwargs) get no authenticators; all other
        requests use ``authentication_classes``.
        """
        # Is this a single-task lookup?
        task_id = self.kwargs.get(self.lookup_field, None)
        if self.request.path.endswith("health_check/") or (
            task_id is not None and self.request.method == "GET"
        ):
            # NOTE(review): this makes GET /<task_id> readable without a token —
            # confirm this public access is intended.
            return []
        return [auth() for auth in self.authentication_classes]

    def create(self, request, *args, **kwargs):
        """Create a diagnosis task (delegates to ``create_task_v2``)."""
        return self.create_task_v2(request, *args, **kwargs)

    def retrieve(self, request, *args, **kwargs):
        """Return one task serialized with ``JobRetrieveSerializer`` plus a ``url``.

        Query params:
            result_filter: optional; when given, only the keys of ``result`` that
                occur in this raw string are kept.
        """
        try:
            instance = self.get_object()
        except Http404:
            return other_response(result={}, message="task不存在", success=False, code=400)

        res = seriaizer.JobRetrieveSerializer(instance).data
        result_filter = request.GET.get("result_filter", None)
        result = res.get("result")
        # Only a dict result can be filtered by key; anything else is returned as-is
        # instead of crashing on `.items()`.
        if result_filter is not None and isinstance(result, dict):
            # NOTE(review): `k in result_filter` is a substring test against the raw
            # query-string value (key "a" also matches "abc"). Kept for compatibility
            # with existing callers — confirm the intended filter format.
            res["result"] = {k: v for k, v in result.items() if k in result_filter}
        res["url"] = "/".join(["", "diagnose", "detail", instance.task_id])
        return success(result=res)

    def list(self, request, *args, **kwargs):
        """List tasks visible through the configured filter backends."""
        queryset = self.filter_queryset(self.get_queryset())
        # exists() issues a cheap LIMIT 1 query instead of loading every row
        # just to test for emptiness.
        if not queryset.exists():
            return success([], total=0)
        return super().list(request, *args, **kwargs)

    def destroy(self, request, *args, **kwargs):
        """Delete the task matching the URL kwargs, or return not_found()."""
        # NOTE(review): unlike list(), this lookup does not go through
        # filter_queryset(), so IsOwnerFilterBackend is presumably not applied
        # here — confirm whether deleting other users' tasks is intended.
        instance = self.get_queryset().filter(**kwargs).first()
        if not instance:
            return not_found()
        self.perform_destroy(instance)
        return success(message="删除成功", code=200, result={})

    def create_task_v2(self, request, *args, **kwargs):
        """Create a task and publish a dispatch event for it.

        Requires ``service_name`` in the request body. Returns ``{"task_id": ...}``
        on success, or an ErrorResponse with the error message.
        """
        try:
            # 1. Check required params
            res = self.require_param_validate(request, ["service_name"])
            if not res["success"]:
                return ErrorResponse(msg=res.get("message", "Missing parameters"))
            data = request.data

            # 2. Create task and hand it to the dispatcher via CEC
            instance = DiagnosisHelper.init(data, request.user)
            self.produce_event_to_cec(
                settings.SYSOM_CEC_DIAGNOSIS_TASK_DISPATCH_TOPIC,
                {"task_id": instance.task_id},
            )
            return success({"task_id": instance.task_id})
        except Exception as e:
            logger.exception(e)
            return ErrorResponse(msg=str(e))

    def sbs_task_create(self, request, *args, **kwargs):
        """Create a step-by-step diagnosis task.

        Creates the task, runs its preprocess script, publishes a
        ``SYSOM_CEC_DIAGNOSIS_TASK_CREATED`` event with the serialized task, and
        returns the preprocess result.

        Args:
            request: DRF request whose body contains ``service_name`` and a
                ``params`` mapping (merged into the task params; ``channel``
                defaults to ``"offline"``).
        """
        try:
            # 1. Check required params
            res = self.require_param_validate(request, ["service_name", "params"])
            if not res["success"]:
                return ErrorResponse(msg=res.get("message", "Missing parameters"))
            data = request.data
            params = {"service_name": data["service_name"], **data["params"]}
            params.setdefault("channel", "offline")

            # 2. Create task
            instance = DiagnosisHelper.init(params, request.user)

            # 3. Invoke preprocess script
            diagnosis_task = DiagnosisHelper.preprocess(instance, True)
            response = seriaizer.JobRetrieveSerializer(instance)
            # The created event is published even when preprocess failed.
            self.produce_event_to_cec(
                settings.SYSOM_CEC_DIAGNOSIS_TASK_CREATED, response.data
            )
            if diagnosis_task is None:
                return ErrorResponse(
                    msg=f"Preprocess script invoke error: {instance.err_msg}"
                )
            return success({"task_id": instance.task_id, **diagnosis_task.to_dict()})
        except Exception as e:
            logger.exception(e)
            return ErrorResponse(msg=str(e))

    def sbs_task_result(self, request, *args, **kwargs):
        """Upload the result of a step-by-step diagnosis task and postprocess it.

        Form fields:
            task_id: id of a task whose status is "Ready" or "Running".
            results: one or more result strings (each becomes one job's stdout).
            content_encoding: "text" (default) or "base64".
            brief: when truthy, return an empty result instead of the task.
            files: optional uploaded files attached to every job result.
        """
        try:
            obj_list = request.FILES.getlist("files")
            task_id = request.POST.get("task_id", None)
            brief = self._parse_flag(request.POST.get("brief", False))
            content_encoding = request.POST.get("content_encoding", "text")
            # QueryDict.getlist() returns [] (never None) for a missing key.
            results = request.POST.getlist("results")
            try:
                content_encoding = ContextType(content_encoding)
            except ValueError:
                return ErrorResponse(
                    msg="content_encoding field can only `text` or `base64`!"
                )

            if not task_id or not results:
                return ErrorResponse(
                    msg="Missing params, required both <task_id> and <result>"
                )

            if content_encoding is ContextType.BASE64:
                results = [base64.b64decode(result).decode() for result in results]

            # 1. Get task (objects.get() would raise instead of returning None)
            instance = JobModel.objects.filter(task_id=task_id).first()
            if instance is None:
                return ErrorResponse(msg=f"No such diagnosis task with id = {task_id}")
            if instance.status not in ["Ready", "Running"]:
                return ErrorResponse(
                    msg=f"Target diganosis task is finished, current status = {instance.status}"
                )

            with tempfile.TemporaryDirectory() as tmp_dir:
                # 2. Save uploaded files; basename() keeps client-supplied names
                # from escaping the temporary directory.
                file_list = []
                for obj in obj_list:
                    local_path = os.path.join(tmp_dir, os.path.basename(obj.name))
                    with open(local_path, "wb") as f:
                        for chunk in obj.chunks():
                            f.write(chunk)
                    file_list.append(
                        FileItem(name=obj.name, remote_path="", local_path=local_path)
                    )

                # 3. Build diagnosis task result
                job_result = DiagnosisTaskResult(
                    0,
                    job_results=[
                        DiagnosisJobResult(
                            0, stdout=result, job=None, file_list=file_list
                        )
                        for result in results
                    ],
                    in_order=False,
                )

                # 4. Invoke postprocess script while the temp files still exist
                DiagnosisHelper.postprocess(instance, job_result)
                if brief:
                    return success("")
                response = seriaizer.JobRetrieveSerializer(instance)
                return success(response.data)
        except Exception as e:
            logger.exception(e)
            return ErrorResponse(msg=str(e))

    def task_hook(self, request, *args, **kwargs):
        """Invoke the diagnosis hook of an existing task.

        Args:
            request: DRF request whose body contains ``task_id`` and ``params``
                (passed to ``DiagnosisHelper.invoke_diagnosis_hook``).
        """
        try:
            # 1. Check required params
            res = self.require_param_validate(request, ["task_id", "params"])
            if not res["success"]:
                return ErrorResponse(msg=res.get("message", "Missing parameters"))
            data = request.data
            task_id = data.get("task_id", None)
            params = data.get("params", {})
            instance = JobModel.objects.filter(task_id=task_id).first()
            if instance is None:
                return ErrorResponse(msg=f"No such diagnosis task with id = {task_id}")
            res = DiagnosisHelper.invoke_diagnosis_hook(instance, params)
            if res.code == 200:
                return success(res.data)
            return ErrorResponse(msg=res.err_msg)
        except Exception as e:
            logger.exception(e)
            return ErrorResponse(msg=str(e))

    def offline_import(self, request, *args, **kwargs):
        """Import an offline diagnosis log and postprocess it into a task.

        Requires ``instance``, ``offline_log`` and ``service_name`` in the body.
        """
        try:
            # 1. Check required params
            res = self.require_param_validate(
                request, ["instance", "offline_log", "service_name"]
            )
            if not res["success"]:
                return ErrorResponse(msg=res.get("message", "Missing parameters"))
            data = request.data

            # 2. Offline import. pop() removes the log from request.data itself;
            # presumably a JSON body (a form-encoded QueryDict would be immutable).
            offline_log = data.pop("offline_log", "")
            instance = DiagnosisHelper.offline_import(data, request.user)

            # 3. Postprocess the log as the stdout of a single job
            async_to_sync(DiagnosisHelper.postprocess_async)(
                instance,
                diagnosis_task_result=DiagnosisTaskResult(
                    0, job_results=[DiagnosisJobResult(0, stdout=offline_log)]
                ),
            )
            return success({"task_id": instance.task_id})
        except Exception as e:
            logger.exception(e)
            return ErrorResponse(msg=str(e))

    def health_check(self, request, *args, **kwargs):
        """Liveness probe; always succeeds with an empty result."""
        return success(result={})

    def get_host(self, request, *args, **kwargs):
        """Proxy the host list from ``{SYSOM_API_URL}/api/v1/host/``.

        Returns the upstream ``data`` field on HTTP 200, otherwise an ErrorResponse.
        """
        host_url = f"{settings.SYSOM_API_URL}/api/v1/host/"
        try:
            # Without a timeout a stalled upstream would block this worker forever.
            res = requests.get(host_url, timeout=self.host_request_timeout)
        except requests.RequestException as e:
            logger.exception(e)
            return ErrorResponse(msg=f"Get host failed: {e}")
        if res.status_code != 200:
            return ErrorResponse(msg=f"Get host failed, status_code={res.status_code}")
        return success(result=res.json().get("data", []))

    @staticmethod
    def _parse_flag(value) -> bool:
        """Interpret a form value as a boolean ("", "0", "false", "no", "off" -> False)."""
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() not in ("", "0", "false", "no", "off", "none")
